{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Keras是一个基于Python编写的高层神经网络API，凭借用户友好性、模块化以及易扩展等有点大受好评，考虑到Keras的优良特性以及它的受欢迎程度，TensorFlow2.0中将Keras的代码吸收了进来，化身为tf.keras模块供用户使用。\n",
    "\n",
    "使用tf.keras提供的高层API，可以轻松得完成建模三部曲——模型构建、训练、评估等工作。下面我们分别来说说如何使用tf.keras完成这三部曲。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1 模型构建"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们知道，神经网络模型就是层的堆叠，tf.keras提供的Sequential类对象就是层容器，可以轻松实现对层的堆叠，创建网络模型。用Sequential创建一个全连接网络模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras  # 为方便使用，keras一般单独导入\n",
    "from tensorflow.keras import layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential()\n",
    "# 往模型中添加一个有64个神经元组成的层，激活函数为relu:\n",
    "model.add(layers.Dense(64, activation='relu'))\n",
    "# 再添加一个:\n",
    "model.add(layers.Dense(64, activation='relu'))\n",
    "# 添加一个有10个神经元的softmax层作为输出层:\n",
    "model.add(layers.Dense(10, activation='softmax'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "也可以在使用Sequential实例化模型时，通过传入由层组成的列表来添加层。我们换一种方式实现上面的模型构建过程,两种方式是完全等效的："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential([\n",
    "    layers.Dense(64, activation='relu'),\n",
    "    layers.Dense(64, activation='relu'),\n",
    "    layers.Dense(10, activation='softmax')]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "你看，用tf.keras创建一个模型，就是这么简单，只需要往Sequential中传入一个个tf.keras.layers定义的层就好了。进一步的，我们研究一下tf.keras.layers怎么个性化地创建层。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义神经网络层通过tf.keras.layers模块中的Dense类实现，Dense类构造参数如下：\n",
    "- units：指定神经元个数，必须是一个正整数。\n",
    "- activation：激活函数，可以是可以是一个可调用对象或标识一个对象的字符串\n",
    "- use_bias：布尔型，是否使用是否使用偏置项\n",
    "- kernel_initializer和bias_initializer：权值、偏置初始化方法,可以是一个可调用对象或标识一个对象的字符串\n",
    "- kernel_regularizer和bias_regularizer：对权值、偏置进行正则化的方法，可以是一个可调用对象或标识一个对象的字符串\n",
    "- activity_regularizer：对层的输出进行正则化的方法，可以是一个可调用对象或标识一个对象的字符串\n",
    "- kernel_constraint和bias_constraint：对权值矩阵、偏置矩阵的约束方法，可以是一个可调用对象或标识一个对象的字符串"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.layers.core.Dense at 0x7f486247abd0>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 有64个神经元，激活函数为sigmoid的层\n",
    "layers.Dense(64, activation='sigmoid')\n",
    "# 或者:\n",
    "layers.Dense(64, activation=tf.keras.activations.sigmoid)\n",
    "\n",
    "# 对权值矩阵进行正则化:\n",
    "layers.Dense(64, kernel_regularizer=tf.keras.regularizers.l1(0.01))\n",
    "\n",
    "# 对偏置向量进行正则化:\n",
    "layers.Dense(64, bias_regularizer=tf.keras.regularizers.l2(0.01))\n",
    "\n",
    "# 指定权值随机正交初始化:\n",
    "layers.Dense(64, kernel_initializer='orthogonal')\n",
    "\n",
    "# 指定偏置为常数:\n",
    "layers.Dense(64, bias_initializer=tf.keras.initializers.Constant(2.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2 训练模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.1 配置：compile()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "建立好模型之后，接下来当然是要进行训练模型了。不过，在训练前还需要做一些配置工作，例如指定优化器、损失函数、评估指标等，这些配置参数的过程一般通过tf.keras.Model.compile方法进行，先来熟悉一下tf.keras.Model.compile方法的三个常用参数：\n",
    "- optimizer：tf.keras.optimizers模块中的优化器**实例化对象**，例如 tf.keras.optimizers.Adam或 tf.keras.optimizers.SGD的实例化对象，当然也可以使用字符串来指代优化器，例如'adam'和'sgd'。\n",
    "- loss：损失函数，例如交叉熵、均方差等，通常是tf.keras.losses模块中定义的可调用对象，也可以用用于指代损失函数的字符串。\n",
    "- metrics：元素为评估方法的list，通常是定义在tf.keras.metrics模块中定义的可调用对象，也可以用于指代评估方法的字符串。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在知道怎么配置模型训练参数后，就可以根据实际应用情况合理选择优化器、损失函数、评估方法等："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 回归模型\n",
    "model.compile(optimizer=tf.keras.optimizers.Adam(0.01),  # 指定优化器，学习率为0.01\n",
    "              loss='mse',       # 指定均方差作为损失函数\n",
    "              metrics=['mae'])  # 添加绝对值误差作为评估方法\n",
    "\n",
    "# 分类模型\n",
    "model.compile(optimizer=tf.keras.optimizers.RMSprop(0.01),\n",
    "              loss=tf.keras.losses.CategoricalCrossentropy(),  # 分类模型多用交叉熵作为损失函数\n",
    "              metrics=[tf.keras.metrics.CategoricalAccuracy()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过compile()配置好模型后，就可以开始训练了。tf.keras中提供了fit()方法对模型进行训练，先来看看fit()方法的主要参数：\n",
    "- x和y：训练数据和目标数据\n",
    "- epochs：训练周期数，每一个周期都是对训练数据集的一次完整迭代\n",
    "- batch_size：簇的大小，一般在数据集是numpy数组类型时使用\n",
    "- validation_data：验证数据集，模型训练时，如果你想通过一个额外的验证数据集来监测模型的性能变换，就可以通过这个参数传入验证数据集\n",
    "- verbose：日志显示方式，verbose=0为不在标准输出流输出日志信息,verbose=1为输出进度条记录，verbose=2为每个epoch输出一行记录\n",
    "- callbacks：回调方法组成的列表，一般是定义在tf.keras.callbacks中的方法\n",
    "- validation_split：从训练数据集抽取部分数据作为验证数据集的比例，是一个0到1之间的浮点数。这一参数在输入数据为dataset对象、生成器、keras.utils.Sequence对象是无效。\n",
    "- shuffle：是否在每一个周期开始前打乱数据"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面分别说说如何使用fit()方法结合numpy数据和tf.data.Dataset数据进行模型训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 1000 samples\n",
      "Epoch 1/10\n",
      "1000/1000 [==============================] - 1s 554us/sample - loss: 206.2688 - categorical_accuracy: 0.1050\n",
      "Epoch 2/10\n",
      "1000/1000 [==============================] - 0s 34us/sample - loss: 911.8347 - categorical_accuracy: 0.0990\n",
      "Epoch 3/10\n",
      "1000/1000 [==============================] - 0s 30us/sample - loss: 1879.7505 - categorical_accuracy: 0.0980\n",
      "Epoch 4/10\n",
      "1000/1000 [==============================] - 0s 28us/sample - loss: 3141.3959 - categorical_accuracy: 0.0940\n",
      "Epoch 5/10\n",
      "1000/1000 [==============================] - 0s 36us/sample - loss: 4673.7791 - categorical_accuracy: 0.1010\n",
      "Epoch 6/10\n",
      "1000/1000 [==============================] - 0s 36us/sample - loss: 6526.8757 - categorical_accuracy: 0.0960\n",
      "Epoch 7/10\n",
      "1000/1000 [==============================] - 0s 31us/sample - loss: 8571.8533 - categorical_accuracy: 0.1020\n",
      "Epoch 8/10\n",
      "1000/1000 [==============================] - 0s 33us/sample - loss: 11070.1039 - categorical_accuracy: 0.0970\n",
      "Epoch 9/10\n",
      "1000/1000 [==============================] - 0s 36us/sample - loss: 13533.4661 - categorical_accuracy: 0.1050\n",
      "Epoch 10/10\n",
      "1000/1000 [==============================] - 0s 25us/sample - loss: 17259.2291 - categorical_accuracy: 0.1000\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f74a4755650>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "data = np.random.random((1000, 32))\n",
    "labels = np.random.random((1000, 10))\n",
    "\n",
    "model.fit(data, labels, epochs=10, batch_size=32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如何使用验证数据集的话，可以这样："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 1000 samples, validate on 100 samples\n",
      "Epoch 1/10\n",
      "1000/1000 [==============================] - 0s 34us/sample - loss: 67219.8359 - categorical_accuracy: 0.0960 - val_loss: 55306.6777 - val_categorical_accuracy: 0.1000\n",
      "Epoch 2/10\n",
      "1000/1000 [==============================] - 0s 32us/sample - loss: 73732.5724 - categorical_accuracy: 0.0920 - val_loss: 89920.2088 - val_categorical_accuracy: 0.1100\n",
      "Epoch 3/10\n",
      "1000/1000 [==============================] - 0s 50us/sample - loss: 79956.1480 - categorical_accuracy: 0.1020 - val_loss: 101092.6750 - val_categorical_accuracy: 0.1000\n",
      "Epoch 4/10\n",
      "1000/1000 [==============================] - 0s 36us/sample - loss: 84322.9844 - categorical_accuracy: 0.0970 - val_loss: 117610.5700 - val_categorical_accuracy: 0.1000\n",
      "Epoch 5/10\n",
      "1000/1000 [==============================] - 0s 38us/sample - loss: 91992.0751 - categorical_accuracy: 0.1130 - val_loss: 94200.0838 - val_categorical_accuracy: 0.1000\n",
      "Epoch 6/10\n",
      "1000/1000 [==============================] - 0s 38us/sample - loss: 97189.2044 - categorical_accuracy: 0.0910 - val_loss: 89020.5294 - val_categorical_accuracy: 0.1100\n",
      "Epoch 7/10\n",
      "1000/1000 [==============================] - 0s 33us/sample - loss: 107109.9905 - categorical_accuracy: 0.0930 - val_loss: 102350.4259 - val_categorical_accuracy: 0.1200\n",
      "Epoch 8/10\n",
      "1000/1000 [==============================] - 0s 41us/sample - loss: 114450.2496 - categorical_accuracy: 0.1010 - val_loss: 102719.3653 - val_categorical_accuracy: 0.1100\n",
      "Epoch 9/10\n",
      "1000/1000 [==============================] - 0s 41us/sample - loss: 124694.8415 - categorical_accuracy: 0.0950 - val_loss: 142269.8362 - val_categorical_accuracy: 0.1100\n",
      "Epoch 10/10\n",
      "1000/1000 [==============================] - 0s 44us/sample - loss: 131952.7791 - categorical_accuracy: 0.0800 - val_loss: 158925.8294 - val_categorical_accuracy: 0.0900\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7f749c548810>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "data = np.random.random((1000, 32))\n",
    "labels = np.random.random((1000, 10))\n",
    "\n",
    "val_data = np.random.random((100, 32))\n",
    "val_labels = np.random.random((100, 10))\n",
    "\n",
    "model.fit(data, labels, epochs=10, batch_size=32,\n",
    "          validation_data=(val_data, val_labels))  # 验证数据集以元组的形式传入"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3 评估与预测"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "是骡子是马，拉出来溜溜就知道了，训练好的模型性能如何，评估测试一下就知道了。可以使用模型自带的evaluate()方法和predict()方法对模型进行评估和预测。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "1000/1 [================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================] - 0s 17us/sample - loss: 161163.7180 - categorical_accuracy: 0.0930\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[153591.27975, 0.093]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 如果是numpy数据，可以这么使用\n",
    "data = np.random.random((1000, 32))\n",
    "labels = np.random.random((1000, 10))\n",
    "\n",
    "model.evaluate(data, labels, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32/32 [==============================] - 0s 579us/step - loss: 153946.2378 - categorical_accuracy: 0.0930\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[153946.23779296875, 0.093]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 如果数Dataset对象，可以这么使用\n",
    "dataset = tf.data.Dataset.from_tensor_slices((data, labels))\n",
    "dataset = dataset.batch(32)\n",
    "\n",
    "model.evaluate(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用predict()方法进行预测："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1000, 10)\n"
     ]
    }
   ],
   "source": [
    "# numpy数据\n",
    "result = model.predict(data, batch_size=32)\n",
    "print(result.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1000, 10)\n"
     ]
    }
   ],
   "source": [
    "# dataset数据\n",
    "result = model.predict(dataset)\n",
    "print(result.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow_gpu",
   "language": "python",
   "name": "tensorflow_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
